import scipy.io
import numpy as np
import torch
from torch_geometric.loader import DataLoader
import os
import sys
sys.path.append(os.getcwd())
from generate_data.power_impedance_dataset import VoltageGraphDataset
from generate_data.case_generation import case_generation
from pypower import loadcase
import utils.config as config
import math
import h5py
from generate_data.graphProcess import convert_to_pyg_data
import pickle
import multiprocessing
import matplotlib.pyplot as plt

def process_field(field_name, mat_file_path):
    """Read one field of the ``results_data`` group from a MATLAB v7.3 (HDF5) file.

    Intended as a ``multiprocessing`` worker: it opens its own read-only
    handle on ``mat_file_path`` instead of sharing one with the parent.

    Args:
        field_name: key of the field inside ``results_data``.
        mat_file_path: path to the .mat (HDF5) file.

    Returns:
        ``(field_name, values)`` where ``values`` has one entry per column of
        the field's first row. Entries that are HDF5 object references are
        resolved with ``dereference`` and transposed (``.T``); any other entry
        is appended unchanged.

    Raises:
        KeyError: if the file has no ``results_data`` group.
    """
    with h5py.File(mat_file_path, 'r') as f:
        if 'results_data' not in f:
            raise KeyError("'results_data' 未在 .mat 文件中找到。请检查文件内容。")

        dataset = f['results_data'][field_name]

        # Read row 0 from disk once. Indexing `dataset[0][i]` inside the loop
        # would re-read the whole row on every iteration (O(n^2) I/O).
        first_row = dataset[0]

        data = []
        for i in range(dataset.shape[1]):
            value = first_row[i]
            if isinstance(value, h5py.Reference):
                value = dereference(value, f).T
            data.append(value)

    return (field_name, data)


def _read_results_fields(mat_file_path, processes=8):
    """Read every field of the ``results_data`` group, one pool task per field.

    The parent opens the file only to validate it and list the field names.
    It closes the file before the worker pool is created, so forked workers do
    not inherit an open HDF5 handle; each worker opens its own in
    ``process_field``. The pool is shut down by its context manager.

    Args:
        mat_file_path: path to the MATLAB v7.3 (HDF5) results file.
        processes: number of worker processes.

    Returns:
        dict mapping each field name to the list returned by ``process_field``.

    Raises:
        KeyError: if the file has no ``results_data`` group.
    """
    with h5py.File(mat_file_path, 'r') as f:
        if 'results_data' not in f:
            raise KeyError("'results_data' 未在 .mat 文件中找到。请检查文件内容。")
        results_data = f['results_data']
        print(results_data.keys())
        field_names = list(results_data.keys())

    with multiprocessing.Pool(processes=processes) as pool:
        results = pool.starmap(
            process_field, [(name, mat_file_path) for name in field_names]
        )
    return dict(results)


def load_data():
    """Build, normalise, split and pickle the voltage / graph dataset.

    Steps:
      1. Run ``case_generation()``, then read every ``results_data`` field
         from ``config.mat_file_path`` in parallel.
      2. For each sample:
         - convert voltage angles from degrees to radians;
         - min-max normalise magnitudes with ``config.VmLb/VmUb`` and angles
           with ``config.VaLb/VaUb``;
         - drop the slack-bus angle;
         - stack P and Q;
         - build a PyG graph with ``convert_to_pyg_data``.
         Power values are divided by ``config.baseMVA`` afterwards.
      3. Reorder the samples so that indices ``i, i+group_num, ...`` are
         grouped together for ``i = 0 .. group_num-1``. Split the reordered
         data 80/20 into train/test.
      4. Store test-set values on ``config``: ``branch_test_id``, ``Ntest``,
         ``test_Pd``, ``test_Qd``, ``real_test_VA``, ``real_test_VM`` and
         ``real_test_V``. Pickle them, together with the train and test
         datasets, to the paths configured in ``config``.

    Returns:
        True on completion.
    """
    case_generation()
    mat_file_path = config.mat_file_path  # make sure this path is correct
    results_dict = _read_results_fields(mat_file_path)

    # Number of samples.
    data_num = len(results_dict['voltage_magnitude'])
    print(f'data_num: {data_num}')

    group_num = config.group_num

    # Per-sample accumulators.
    voltage_data = []
    power_data = []
    branch_id_data = []
    graph_data_list = []

    real_voltage_angle_list = []
    real_voltage_magnitude_list = []
    print(f"config.bus_slack_VA:{config.bus_slack_VA}")
    # The slack-bus angle mapped into the same normalised space as the angle
    # targets. It is used below to re-insert the removed slack column.
    slack_VA = (config.bus_slack_VA * math.pi / 180 - config.VaLb) / (config.VaUb - config.VaLb)
    print(f"slack_VA:{slack_VA}")

    for idx in range(data_num):
        # Extract this sample's fields.
        voltage_magnitude = results_dict['voltage_magnitude'][idx].squeeze()
        voltage_angle = results_dict['voltage_angle'][idx].squeeze()
        p_sampled = results_dict['p_sampled'][idx].squeeze()
        q_sampled = results_dict['q_sampled'][idx].squeeze()
        branch_id = results_dict['id'][idx].squeeze(1)
        adj_matrix = results_dict['adj_matrix'][idx]
        node_features = results_dict['node_features'][idx]
        edge_features_norm = results_dict['edge_features_norm'][idx]

        # Degrees -> radians. Keep the un-normalised values for reference.
        voltage_angle = voltage_angle * math.pi / 180
        real_voltage_angle = voltage_angle
        voltage_angle = (voltage_angle - config.VaLb) / (config.VaUb - config.VaLb)
        real_voltage_magnitude = voltage_magnitude
        voltage_magnitude = (voltage_magnitude - config.VmLb) / (config.VmUb - config.VmLb)

        # Drop the slack-bus angle from the target. It is re-inserted from
        # config.bus_slack_VA when the test set is reconstructed below.
        voltage_angle = np.delete(voltage_angle, config.bus_slack, axis=0)

        # [normalised magnitudes..., normalised non-slack angles...]
        voltage_combined = np.concatenate((voltage_magnitude, voltage_angle))
        # One column for P and one for Q.
        power_combined = np.column_stack((p_sampled, q_sampled))

        # NOTE(review): if the squeezed inputs are 1-D, voltage_combined is
        # 1-D and .T is a no-op, while power_combined becomes (2, n_bus).
        voltage_combined = voltage_combined.T
        power_combined = power_combined.T

        graph_data = convert_to_pyg_data(adj_matrix, node_features, edge_features_norm)

        voltage_data.append(voltage_combined)
        power_data.append(power_combined)
        branch_id_data.append(int(branch_id))
        graph_data_list.append(graph_data)
        real_voltage_magnitude_list.append(real_voltage_magnitude)
        real_voltage_angle_list.append(real_voltage_angle)

        if idx % 10000 == 1:
            print(f'processing:{idx}/{data_num}')

    # Convert the accumulators to arrays. Power is expressed in per-unit.
    voltage_data = np.array(voltage_data)
    power_data = np.array(power_data) / config.baseMVA
    branch_id_data = np.array(branch_id_data)
    real_voltage_magnitude_list = np.array(real_voltage_magnitude_list)
    real_voltage_angle_list = np.array(real_voltage_angle_list)

    # Print the final data shapes.
    print(f"\nFinal voltage_data shape: {voltage_data.shape}")
    print(f"Final power_data shape: {power_data.shape}")
    print("group_num:{}".format(group_num))
    print(f'branch_data.shape:{branch_id_data.shape}')
    print(f'real_voltage_magnitude_list.shape:{real_voltage_magnitude_list.shape}')

    # New index order: all samples with index % group_num == 0 first, then
    # those with index % group_num == 1, and so on.
    new_order = []
    for i in range(group_num):
        new_order.extend(range(i, data_num, group_num))

    # Rearrange every per-sample array with the new order.
    branch_id_rearranged = branch_id_data[new_order]
    voltage_data_rearranged = voltage_data[new_order]
    power_data_rearranged = power_data[new_order]
    graph_data_list_rearranged = [graph_data_list[idx] for idx in new_order]
    real_voltage_magnitude_list_arranged = real_voltage_magnitude_list[new_order]
    real_voltage_angle_list_arranged = real_voltage_angle_list[new_order]

    # 80/20 train/test split, with the boundary computed once.
    split = int(data_num * 0.8)

    train_graph = graph_data_list_rearranged[:split]
    train_voltage = voltage_data_rearranged[:split]
    train_power = power_data_rearranged[:split]

    test_graph = graph_data_list_rearranged[split:]
    test_voltage = voltage_data_rearranged[split:]
    test_power = power_data_rearranged[split:]
    # NOTE(review): the two slices below are not used later in this function.
    test_real_voltage_magnitude = real_voltage_magnitude_list_arranged[split:]
    test_real_voltage_angel = real_voltage_angle_list_arranged[split:]
    config.branch_test_id = branch_id_rearranged[split:]

    train_voltage_tensor = torch.tensor(train_voltage, dtype=torch.float32)
    test_voltage_tensor = torch.tensor(test_voltage, dtype=torch.float32)

    # Build the train and test Datasets.
    train_dataset = VoltageGraphDataset(train_graph, train_voltage_tensor)
    test_dataset = VoltageGraphDataset(test_graph, test_voltage_tensor)

    # Build the DataLoader. The test set is not shuffled.
    print("create trian and test loaders")
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers = 4)

    # Recover physical test voltages from the normalised targets.
    print(f'test_voltage_tensor.shape:{test_voltage_tensor.shape}')
    test_voltage_numpy = test_voltage_tensor.clone().numpy()
    real_test_voltage_M = test_voltage_numpy[:, 0:config.Nbus]
    real_test_voltage_A = test_voltage_numpy[:, config.Nbus:]

    print(f'real_test_voltage_M.sahpe:{real_test_voltage_M.shape};config.VmUb.shape:{config.VmUb.shape}')
    real_test_voltage_M = real_test_voltage_M * (config.VmUb - config.VmLb) + config.VmLb
    # The angle columns are still normalised here, so insert the normalised
    # slack angle (slack_VA) rather than a raw 0 before de-normalising.
    real_test_voltage_A = np.insert(real_test_voltage_A, config.bus_slack, values=slack_VA, axis=1)
    real_test_voltage_A = real_test_voltage_A * (config.VaUb - config.VaLb) + config.VaLb

    # Use the actual size of the test split. int(data_num * 0.2) can be one
    # smaller than data_num - int(data_num * 0.8) (e.g. data_num=11 gives 2 vs 3).
    config.Ntest = len(test_voltage)
    config.test_Pd = test_power[:, 0, :].squeeze()
    config.test_Qd = test_power[:, 1, :].squeeze()
    config.real_test_VA = real_test_voltage_A
    config.real_test_VM = real_test_voltage_M
    print(f'data_num:{data_num};Ntest:{config.Ntest}')

    # Complex bus voltages V = |V| * exp(j * angle).
    config.real_test_V = real_test_voltage_M * np.exp(1j * real_test_voltage_A)

    # Collect the config values to persist.
    config_dict = {
        'Ntest': config.Ntest,
        'test_Pd': config.test_Pd,
        'test_Qd': config.test_Qd,
        'real_test_VA': config.real_test_VA,
        'real_test_VM': config.real_test_VM,
        'real_test_V': config.real_test_V,
        'branch_test_id': config.branch_test_id
    }

    # Persist the config values and both datasets with pickle.
    with open(config.piclkle_config_dataset_path, 'wb') as f:
        pickle.dump(config_dict, f)
    with open(config.pickle_train_dataset_path, 'wb') as f:
        pickle.dump(train_dataset, f)
    with open(config.pickle_test_dataset_path, 'wb') as f:
        pickle.dump(test_dataset, f)

    return True



def dereference(value, f):
    """Resolve an HDF5 object reference from a MATLAB v7.3 file to its data.

    Args:
        value: an h5py object reference, an ``h5py.Dataset``, or any other value.
        f: the open ``h5py.File`` that the reference belongs to.

    Returns:
        - For a reference to a Dataset: its full contents (``dataset[()]``).
        - For a reference to a Group containing ``data``, ``ir`` and ``jc``
          (MATLAB's sparse-matrix storage): a ``scipy.sparse.csc_matrix``.
        - For a Dataset passed directly: its full contents.
        - For anything else: ``value`` unchanged.

    Raises:
        ValueError: if a referenced Group lacks ``data``/``ir``/``jc``.
        TypeError: if the reference resolves to neither a Group nor a Dataset.
    """
    # Only scipy.io is imported at module level, so import the submodule
    # explicitly instead of relying on it being loaded as a side effect.
    import scipy.sparse

    if isinstance(value, h5py.h5r.Reference):
        # Follow the reference to the object it points to.
        resolved_value = f[value]
        if isinstance(resolved_value, h5py.Group):
            # A Group may be MATLAB's sparse-matrix storage layout.
            if 'data' in resolved_value and 'ir' in resolved_value and 'jc' in resolved_value:
                data = resolved_value['data'][()]
                ir = resolved_value['ir'][()]
                jc = resolved_value['jc'][()]
                # Without an explicit shape, scipy infers rows as max(ir) + 1,
                # which silently drops trailing all-zero rows. MATLAB records
                # the row count in the group's 'MATLAB_sparse' attribute. The
                # column count is len(jc) - 1 by CSC construction.
                shape = None
                if 'MATLAB_sparse' in resolved_value.attrs:
                    shape = (int(resolved_value.attrs['MATLAB_sparse']), len(jc) - 1)
                return scipy.sparse.csc_matrix((data, ir, jc), shape=shape)
            raise ValueError("Group 不是稀疏矩阵的存储结构。")
        if isinstance(resolved_value, h5py.Dataset):
            # A plain Dataset: read all of it.
            return resolved_value[()]
        raise TypeError(f"解引用结果类型不支持: {type(resolved_value)}")
    if isinstance(value, h5py.Dataset):
        # Not a reference but a Dataset: read all of it.
        return value[()]
    # Any other value is returned as-is.
    return value



import pickle

def reload_config():
    """Restore the cached training dataset and the saved test-set config values.

    Steps:
      1. Run ``case_generation()``.
      2. Unpickle the training dataset from ``config.pickle_train_dataset_path``
         and print its shapes.
      3. Unpickle the dict stored at ``config.piclkle_config_dataset_path``
         and copy its entries back onto ``config``.
      4. Write a scatter plot of ``config.real_test_V``.

    Note: ``pickle.load`` can execute arbitrary code, so only load files
    produced by this pipeline.

    Returns:
        The unpickled training dataset.
    """
    case_generation()

    with open(config.pickle_train_dataset_path, 'rb') as fh:
        cached_train = pickle.load(fh)
    cached_train.print_shapes()

    with open(config.piclkle_config_dataset_path, 'rb') as fh:
        saved_values = pickle.load(fh)

    # Copy the saved entries back onto config, in a fixed order.
    restored_keys = (
        'Ntest',
        'test_Pd',
        'test_Qd',
        'real_test_VA',
        'real_test_VM',
        'real_test_V',
        'branch_test_id',
    )
    for key in restored_keys:
        setattr(config, key, saved_values[key])

    plot_complex_voltage_data(config.real_test_V)

    print("Config loaded successfully!")
    return cached_train


def plot_complex_voltage_data(data, output_filename='voltage_scatter_plot.png'):
    """Scatter-plot complex voltages in the complex plane and save the image.

    Every element of ``data`` becomes one point, with the real part on the
    x-axis and the imaginary part on the y-axis. The input is flattened, so
    any shape works.

    Args:
        data: complex array-like, e.g. an ``(n_samples, n_bus)`` array of bus voltages.
        output_filename: path of the image file to write.
            Defaults to 'voltage_scatter_plot.png'.

    Returns:
        str: ``output_filename``, the path of the saved image.
    """
    values = np.asarray(data).ravel()

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.scatter(values.real, values.imag, s=1, alpha=0.5)
        ax.set_title('Scatter Plot of Complex Voltage Data')
        ax.set_xlabel('Real Part')
        ax.set_ylabel('Imaginary Part')
        fig.savefig(output_filename)
    finally:
        # Release the figure. Otherwise every call leaves one more open
        # figure registered in pyplot.
        plt.close(fig)

    return output_filename